cross in sklearn


For the concept of cross-validation, see the definition on the official scikit-learn website.
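The core K-fold idea can be illustrated with a minimal sketch. It assumes scikit-learn's KFold splitter (not used elsewhere in this post) and ten toy samples. The data are cut into K folds, and in each round one fold is held out for validation while the remaining K-1 folds are used for training.

```python
import numpy as np
from sklearn.model_selection import KFold

# Ten toy samples; only their positions matter for showing how the data are split.
X = np.arange(10).reshape(-1, 1)

kf = KFold(n_splits=3)  # K = 3, no shuffling, so each fold is a contiguous block
test_folds = []
for k, (train_idx, test_idx) in enumerate(kf.split(X)):
    # Round k: fit on the other K-1 folds (train_idx), validate on fold k (test_idx).
    print(f"round {k}: train={train_idx.tolist()} validate={test_idx.tolist()}")
    test_folds.append(test_idx)
```

With 10 samples and K = 3, KFold gives the extra sample to the first fold, so the validation folds are [0, 1, 2, 3], [4, 5, 6] and [7, 8, 9]. Every sample is validated exactly once.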

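scikit-learn functions for computing cross-validation:

cross_val_score: returns the score on each of the K folds; the mean of the K scores is the model's average performance.

cross_val_predict: returns, for every sample, the prediction made by the model that did not see that sample during K-fold cross-validation (i.e. when the sample was in the validation fold).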
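The difference in what the two functions return shows up directly in their output shapes. This is a minimal sketch, assuming the iris data, a decision tree with a fixed random_state (added here for reproducibility) and cv=5.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_predict, cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0)  # fixed seed so reruns give the same trees

# One score per fold: an array of length K.
scores = cross_val_score(clf, X, y, cv=5)
print(scores.shape)
print(f"mean accuracy: {scores.mean():.3f} (std {scores.std():.3f})")

# One out-of-fold prediction per sample: an array of length n_samples.
preds = cross_val_predict(clf, X, y, cv=5)
print(preds.shape)
```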

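How they work:

cross_val_score: in each of the K rounds, trains the model on K-1 folds, evaluates it on the remaining fold, and keeps that fold's score.

cross_val_predict: in each of the K rounds, trains the model on K-1 folds, predicts the remaining fold, and writes those samples' predictions into the final output, so every sample is predicted exactly once.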
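The two procedures above can be written out by hand as a single loop. This is a sketch, not the library's internal code. It assumes that an integer cv on a classifier resolves to StratifiedKFold without shuffling, which is scikit-learn's documented behaviour for binary and multiclass targets. It also assumes a fixed random_state so the hand-written loop and the library calls fit identical trees.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold, cross_val_predict, cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
base = DecisionTreeClassifier(random_state=0)
cv = StratifiedKFold(n_splits=3)  # what cv=3 means for a classifier

fold_scores = []            # what cross_val_score returns
oof_pred = np.empty_like(y)  # what cross_val_predict returns
for train_idx, test_idx in cv.split(X, y):
    model = clone(base)                      # fresh, unfitted copy for every fold
    model.fit(X[train_idx], y[train_idx])    # train on K-1 folds
    pred = model.predict(X[test_idx])        # predict the held-out fold
    fold_scores.append(accuracy_score(y[test_idx], pred))
    oof_pred[test_idx] = pred                # each sample is filled in exactly once

# Compare with the library functions on the same splits.
same_scores = np.allclose(fold_scores, cross_val_score(base, X, y, cv=cv))
same_preds = np.array_equal(oof_pred, cross_val_predict(base, X, y, cv=cv))
print(same_scores, same_preds)
```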

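Conclusions:

The average score from cross_val_score can be used as an estimate of the model's generalization performance.

The per-sample predictions from cross_val_predict should not be used as a measure of generalization performance. Each prediction comes from a different fold's model, and a metric computed over all predictions pooled together can differ from the average of the per-fold scores. The scikit-learn documentation notes that the two agree only when all test folds have the same size and the metric decomposes over samples.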
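For accuracy, the pooled cross_val_predict score is the fold-size-weighted mean of the per-fold scores, not their plain mean. This follows because it equals total correct predictions divided by n. A sketch, assuming StratifiedKFold with 4 splits so that the 150 iris samples cannot be divided into equal folds:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold, cross_val_predict, cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0)
cv = StratifiedKFold(n_splits=4)  # 150 samples do not split evenly into 4 folds

scores = cross_val_score(clf, X, y, cv=cv)
pooled = accuracy_score(y, cross_val_predict(clf, X, y, cv=cv))

# Fold sizes, taken from the same deterministic splitter.
sizes = np.array([len(test) for _, test in cv.split(X, y)])

print("mean of fold scores:     ", scores.mean())
print("size-weighted fold mean: ", np.average(scores, weights=sizes))
print("pooled cross_val_predict:", pooled)
```

The gap between the first two lines is usually small for accuracy. For metrics that do not decompose over samples, such as ROC AUC, pooling can change the result more.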

Code example:

```python
from sklearn import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import cross_val_predict, cross_val_score, cross_validate
from sklearn.tree import DecisionTreeClassifier

# Load the iris dataset
iris = datasets.load_iris()
iris_train = iris.data
iris_target = iris.target
print(iris_train.shape)   # (150, 4)
print(iris_target.shape)  # (150,)

# Build a decision tree classifier and fit it on the full dataset
tree_clf = DecisionTreeClassifier()
tree_clf.fit(iris_train, iris_target)
tree_predict = tree_clf.predict(iris_train)

# Accuracy of the decision tree, measured on its own training data (hence optimistic)
print("Accuracy:", accuracy_score(iris_target, tree_predict))
# Accuracy: 1.0

# cross_val_score: accuracy on each fold
tree_scores = cross_val_score(tree_clf, iris_train, iris_target, cv=3)
print(tree_scores)
# [0.98039216 0.92156863 1.        ]

# cross_val_predict: the out-of-fold prediction for every sample
tree_predict = cross_val_predict(tree_clf, iris_train, iris_target, cv=3)
print(tree_predict)
print(len(tree_predict))
print(accuracy_score(iris_target, tree_predict))
# [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2]
# 150
# 0.96

# For comparison: predictions of the model fitted on the full dataset
print(tree_clf.predict(iris_train))
# [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]

# cross_validate: the same per-fold test scores that cross_val_score reports
# (cross_val_score is built on top of cross_validate), plus fit and scoring times
tree_val = cross_validate(tree_clf, iris_train, iris_target, cv=3)
print(tree_val)
# {'fit_time': array([0., 0., 0.]), 'score_time': array([0., 0., 0.]), 'test_score': array([0.98039216, 0.92156863, 0.97916667])}
```
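In the run above, the third-fold score differs between cross_val_score (1.0) and cross_validate (0.97916667), even though the folds are the same. A likely cause, inferred here rather than stated by the author, is that DecisionTreeClassifier was created without random_state. Ties between equally good splits can then be broken differently on each fit. A sketch that fixes the seed (42 is an arbitrary choice) so both calls agree:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score, cross_validate
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
tree_clf = DecisionTreeClassifier(random_state=42)  # fixed seed: same tree on the same data

scores = cross_val_score(tree_clf, X, y, cv=3)
result = cross_validate(tree_clf, X, y, cv=3)

print(scores)
print(result["test_score"])
```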

For the values accepted by the scoring parameter of these cross-validation functions, see 3.3. Metrics and scoring: quantifying the quality of predictions (scikit-learn 1.0.2 documentation).
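A short sketch of the scoring parameter, assuming the built-in scorer names "accuracy" and "f1_macro" from that documentation page. cross_val_score takes a single metric. cross_validate also accepts a list of metrics and, with return_train_score=True, adds the scores on the training folds.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score, cross_validate
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0)

# A single metric by name: per-fold macro-averaged F1 instead of the default accuracy.
f1_scores = cross_val_score(clf, X, y, cv=3, scoring="f1_macro")

# Several metrics at once; each gets its own "test_<name>" (and "train_<name>") entry.
result = cross_validate(
    clf, X, y, cv=3,
    scoring=["accuracy", "f1_macro"],
    return_train_score=True,
)
print(f1_scores)
print(sorted(result))
```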


